import asyncio
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.controller_api import DistributedBatch
from areal.api.engine_api import TrainEngine
from areal.api.io_struct import (
    AllocationMode,
    FinetuneSpec,
    SaveLoadMeta,
    WeightUpdateMeta,
)
from areal.api.scheduler_api import Job, Scheduler, Worker
from areal.api.workflow_api import RolloutWorkflow
from areal.controller.batch import DistributedBatchMemory
from areal.controller.rollout_controller import RolloutController
from areal.platforms import current_platform
from areal.utils import logging, name_resolve, names

logger = logging.getLogger(__name__)


class TrainController:
    """Controller for managing distributed training across multiple workers.

    This class orchestrates the lifecycle of training workers, handles data
    distribution across data-parallel groups, and provides a unified interface
    for training operations. It manages worker creation, engine initialization,
    and coordinates method calls across distributed workers.

    The controller automatically handles:
    - Worker creation and lifecycle management via scheduler
    - Data splitting across data-parallel groups
    - Result merging from multiple workers
    - Distributed training configuration (MASTER_ADDR, MASTER_PORT)
    """

    def __init__(
        self,
        train_engine: type[TrainEngine],
        config: TrainEngineConfig,
        scheduler: Scheduler,
    ):
        """Store the engine class, config and scheduler; no workers are created yet.

        Parameters
        ----------
        train_engine : type[TrainEngine]
            Engine class whose import path is sent to every worker in
            `initialize()`.
        config : TrainEngineConfig
            Engine configuration forwarded to every worker.
        scheduler : Scheduler
            Scheduler used to create workers and dispatch engine calls.
        """
        self.train_engine = train_engine
        self.config = config
        self.scheduler = scheduler

        # Assigned in `initialize()`.
        self.alloc_mode: AllocationMode
        self.workers: list[Worker] = []
        # Boolean list indicating which workers are data-parallel heads.
        # Only DP head workers receive data slices; others get data via broadcast.
        self.workers_is_dp_head: list[bool] = []
        self.parallel_strategy: ParallelStrategy | None = None

        # Assigned in `initialize()`; `destroy()` tolerates it being unset.
        self._worker_role: str

        # Assigned by `connect_engine()`.
        self.rollout: RolloutController | None = None
        self.weight_update_group_initialized = False

    def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
        """Placeholder method for process group creation.

        This is a dummy method maintained for API compatibility. The actual
        process group creation happens during `initialize()` when engines are
        initialized on workers.

        Parameters
        ----------
        parallel_strategy : ParallelStrategy | None, optional
            Parallel strategy configuration (currently unused), by default None
        """
        pass

    def initialize(
        self,
        role: str,
        alloc_mode: AllocationMode,
        ft_spec: FinetuneSpec,
        **kwargs,
    ):
        """Initialize environments for distributed training and load models.

        Creates ``alloc_mode.train.world_size`` workers via the scheduler,
        derives MASTER_ADDR/MASTER_PORT from the rank-0 worker, creates and
        initializes an engine on every worker, and finally records which
        workers are data-parallel heads.

        Parameters
        ----------
        role : str
            Role identifier for the workers
        alloc_mode : AllocationMode
            Allocation mode configuration for distributed setup
        ft_spec : FinetuneSpec
            Finetune specification for model initialization
        **kwargs
            Additional keyword arguments passed to engine initialization

        Raises
        ------
        RuntimeError
            If the scheduler reports no workers for ``role`` after creation.
        """
        # Store configuration
        self._worker_role = role
        self.alloc_mode = alloc_mode

        self.parallel_strategy = alloc_mode.train

        # Create job specification for scheduler
        # Convert scheduling_spec tuple to list for scheduler compatibility
        # The scheduler will handle task replication across workers if needed
        job = Job(
            replicas=alloc_mode.train.world_size,
            tasks=list(self.config.scheduling_spec),
            scheduling_strategy=self.config.scheduling_strategy,
            role=self._worker_role,
        )

        # Create workers via scheduler
        logger.info("Creating workers via scheduler...")
        worker_ids = self.scheduler.create_workers(job=job)
        logger.info(f"Workers created: {worker_ids}")

        # Wait for workers to be ready
        logger.info("Waiting for workers to be ready...")
        self.workers = self.scheduler.get_workers(role=job.role)
        logger.info(f"Workers ready: {[w.id for w in self.workers]}")
        if not self.workers:
            raise RuntimeError(
                f"No workers available for role '{self._worker_role}' after creation."
            )

        # Determine distributed training master address and port from rank 0 worker.
        # These are used for PyTorch distributed initialization across workers.
        # Prefer engine_ports[1] when that index exists, otherwise fall back to
        # worker_ports[1] (a one-element engine_ports list must not crash).
        rank0_worker = self.workers[0]
        engine_ports = rank0_worker.engine_ports
        if engine_ports and len(engine_ports) > 1:
            self._master_port = int(engine_ports[1])
        else:
            self._master_port = int(rank0_worker.worker_ports[1])
        self._master_addr = rank0_worker.ip

        logger.info(
            f"Distributed training: MASTER_ADDR={self._master_addr}, MASTER_PORT={self._master_port}"
        )

        # Construct engine class import path for dynamic loading on workers
        # Workers will import and instantiate the engine class using this path
        engine_class = self.train_engine
        engine_path = f"{engine_class.__module__}.{engine_class.__name__}"

        # Create and initialize engines on workers
        self._run_async_task(self._async_create_engines(engine_path))
        self._run_async_task(self._async_initialize_engines(ft_spec, **kwargs))

        # Identify DP head workers
        self._identify_dp_heads()

        logger.info("TrainController initialization complete")

    def _run_async_task(self, task):
        """Run a coroutine to completion on a fresh event loop and return its result.

        Uses :func:`asyncio.run`, so it must not be called from a thread that
        already has a running event loop.
        """
        return asyncio.run(task)

    async def _async_create_engines(self, engine_path: str):
        """Create engine instances on all workers. Sets distributed env vars before creation.

        Parameters
        ----------
        engine_path : str
            Dotted import path of the engine class (``module.ClassName``).
        """
        logger.info("Creating engines on workers...")

        async def _setup_worker(worker: Worker, rank: int):
            env = {
                "RANK": str(rank),
                "WORLD_SIZE": str(len(self.workers)),
                "MASTER_ADDR": str(self._master_addr),
                "MASTER_PORT": str(self._master_port),
                "LOCAL_RANK": "0",  # NOTE: local rank is always 0 while each process use only one GPU
            }
            # Env must be in place before the engine is constructed.
            await self.scheduler.set_worker_env(worker.id, env)
            await self.scheduler.create_engine(
                worker_id=worker.id,
                engine=engine_path,
                config=self.config,
            )

        # Worker index in `self.workers` is used as the global rank.
        tasks = [
            _setup_worker(worker, rank) for rank, worker in enumerate(self.workers)
        ]
        await asyncio.gather(*tasks)
        logger.info("Engines created on all workers!")

    async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs):
        """Initialize engines: create process groups, then load models and setup optimizers.

        Both phases are issued to every worker concurrently; phase 2 starts only
        after all phase-1 calls have completed.
        """
        logger.info("Calling engine initialization...")
        # Phase 1: Create process groups for distributed training
        tasks = [
            self.scheduler.async_call_engine(
                worker_id=worker.id,
                method="create_process_group",
                parallel_strategy=self.parallel_strategy,
                _should_bcast=False,
            )
            for worker in self.workers
        ]
        await asyncio.gather(*tasks)
        # Phase 2: Initialize engines (load models, setup optimizers, etc.)
        tasks = [
            self.scheduler.async_call_engine(
                worker_id=worker.id,
                method="initialize",
                ft_spec=ft_spec,
                _should_bcast=False,
                **kwargs,
            )
            for worker in self.workers
        ]
        await asyncio.gather(*tasks)
        logger.info("All engines are initialized!")

    def _identify_dp_heads(self):
        """Query workers to identify DP heads. Stores result in self.workers_is_dp_head.

        The stored list is aligned with ``self.workers`` (one entry per worker).
        """
        logger.info("Identifying DP head workers...")

        async def _get_dp_head():
            tasks = [
                self.scheduler.async_call_engine(
                    worker_id=worker.id, method="is_data_parallel_head"
                )
                for worker in self.workers
            ]
            return await asyncio.gather(*tasks)

        self.workers_is_dp_head = self._run_async_task(_get_dp_head())

    def destroy(self):
        """Destroy the controller and release GPU memory of models.

        Cleans up all resources including workers, engines, and internal state.
        Cleanup is best-effort: failures are logged rather than raised.
        """
        logger.info("Destroying TrainController...")

        # First destroy engines to release GPU memory
        if self.workers:
            logger.info("Destroying engines on all workers...")
            try:

                async def _destroy_all_engines():
                    tasks = [
                        self.scheduler.async_call_engine(worker.id, "destroy")
                        for worker in self.workers
                    ]
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    # Keep going past individual failures, but surface them
                    # instead of silently dropping the collected exceptions.
                    for worker, result in zip(self.workers, results):
                        if isinstance(result, BaseException):
                            logger.error(
                                f"Error destroying engine on worker {worker.id}: {result}"
                            )

                self._run_async_task(_destroy_all_engines())
                logger.info("Engines destroyed")
            except Exception as e:
                logger.error(f"Error destroying engines: {e}")

        # Then delete workers via scheduler. `_worker_role` is only assigned once
        # `initialize()` has run, so there is nothing to delete without it.
        worker_role = getattr(self, "_worker_role", None)
        if worker_role is not None:
            try:
                logger.info("Deleting all workers...")
                self.scheduler.delete_workers(role=worker_role)
                logger.info("Workers deleted")
            except Exception as e:
                logger.error(f"Error deleting workers: {e}")

        # Clear worker lists
        self.workers.clear()
        self.workers_is_dp_head.clear()

        logger.info("TrainController destroyed")

    def _custom_function_call(self, method: str, *args, **kwargs):
        """Dispatch method call to workers: split batches, replicate args, merge results."""
        dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs)
        results = self._run_async_task(
            self._call_with_dispatched_inputs(method, dp_split_args, dp_split_kwargs)
        )
        return self._merge_results(self._dp_head_results(results), method)

    async def _async_custom_function_call(self, method: str, *args, **kwargs):
        """Async version of _custom_function_call."""
        dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs)
        results = await self._call_with_dispatched_inputs(
            method, dp_split_args, dp_split_kwargs
        )
        return self._merge_results(self._dp_head_results(results), method)

    def _dp_head_results(self, results: list[Any]) -> list[Any]:
        """Keep only the per-worker results that came from DP head workers.

        ``results`` must be aligned with ``self.workers`` (one entry per worker).
        """
        return [
            result
            for result, is_dp_head in zip(
                results, self.workers_is_dp_head, strict=True
            )
            if is_dp_head
        ]

    def _dispatch_inputs(
        self, *args, **kwargs
    ) -> tuple[list[list[Any]], dict[str, list[Any]]]:
        """Split DistributedBatch across DP groups, replicate other args to all DP heads.

        Returns
        -------
        tuple[list[list[Any]], dict[str, list[Any]]]
            ``(split_args, split_kwargs)``. Every entry is a list indexed by
            DP-head position: a ``DistributedBatch`` is split with
            `_align_batches_with_dp`, any other value is repeated ``dp_size``
            times.
        """

        def _split(value):
            if isinstance(value, DistributedBatch):
                # Split across DP groups
                return self._align_batches_with_dp(value, rebalance=True)
            # Replicate to all DP heads
            return [value] * self.parallel_strategy.dp_size

        split_args = [_split(arg) for arg in args]
        split_kwargs = {k: _split(v) for k, v in kwargs.items()}
        return split_args, split_kwargs

    async def _call_with_dispatched_inputs(
        self,
        method: str,
        dp_split_args: list[list[Any]],
        dp_worker_kwargs: dict[str, list[Any]],
    ) -> list[Any]:
        """Call method on all workers. DP heads get data slices, others get empty args (broadcast via RPC).

        Parameters
        ----------
        method : str
            Engine method name invoked on every worker.
        dp_split_args : list[list[Any]]
            Positional arguments; ``dp_split_args[i][d]`` is the i-th argument
            for the d-th DP head.
        dp_worker_kwargs : dict[str, list[Any]]
            Keyword arguments; ``dp_worker_kwargs[k][d]`` is the value of ``k``
            for the d-th DP head.

        Returns
        -------
        list[Any]
            One result per worker, in ``self.workers`` order.
        """

        def _to_payload(value):
            # Convert DistributedBatch to its raw data for RPC serialization.
            # TODO: Consider passing metadata instead of full tensors to reduce
            # network overhead, especially for large batches
            return value.get_data() if isinstance(value, DistributedBatch) else value

        tasks = []
        dp_idx = 0
        for worker, is_dp_head in zip(
            self.workers, self.workers_is_dp_head, strict=True
        ):
            if is_dp_head:
                # Get this DP head worker's slice of each argument
                worker_args = [_to_payload(splits[dp_idx]) for splits in dp_split_args]
                worker_kwargs = {
                    k: _to_payload(splits[dp_idx])
                    for k, splits in dp_worker_kwargs.items()
                }
                dp_idx += 1
            else:
                # Non-DP-head workers get empty arguments
                # They will receive data via broadcast in RPC server
                worker_args = []
                worker_kwargs = {}

            tasks.append(
                self.scheduler.async_call_engine(
                    worker.id,
                    method,
                    *worker_args,
                    **worker_kwargs,
                )
            )
        return await asyncio.gather(*tasks)

    def _merge_results(self, results, method):
        """Merge results from DP heads into a single value.

        - Tensors with two or more dims (assumed ``[batch, seq, ...]``) are
          right-padded with zeros along dim 1 to the longest sequence and
          concatenated along dim 0.
        - 1-D tensors are concatenated along dim 0.
        - 0-D tensors are returned from the first DP head (see note below).
        - Non-empty dicts containing tensors become a ``DistributedBatchMemory``
          (via ``DistributedBatchMemory.concat`` when ``attention_mask`` is
          present, otherwise by concatenating tensor values along dim 0);
          an empty dict becomes an empty ``DistributedBatchMemory``.
        - Anything else is returned from the first DP head.

        Raises
        ------
        RuntimeError
            If ``results`` is empty (no DP head produced a result).
        """
        if not results:
            raise RuntimeError(
                f"No results received from data-parallel head workers for method '{method}'."
            )
        first_result = results[0]

        if isinstance(first_result, torch.Tensor):
            if first_result.ndim == 0:
                # A scalar (e.g. a loss) has no batch dim to concatenate; treat
                # it like the other scalar statistics below.
                # NOTE(review): assumes it is already reduced across workers.
                return first_result
            if first_result.ndim == 1:
                return torch.cat(results, dim=0)
            # Pad tensors to max sequence length and concatenate along batch dimension
            max_length = max(tensor.shape[1] for tensor in results)
            padded_tensors = []
            for tensor in results:
                # F.pad takes (left, right) pairs starting from the LAST dim:
                # leave trailing dims untouched, right-pad dim 1 (sequence).
                pad = (0, 0) * (tensor.ndim - 2) + (0, max_length - tensor.shape[1])
                padded_tensors.append(
                    torch.nn.functional.pad(tensor, pad, value=0.0)
                )
            return torch.cat(padded_tensors, dim=0)

        if isinstance(first_result, dict):
            if len(first_result) == 0:
                return DistributedBatchMemory.from_dict({})

            if any(isinstance(v, torch.Tensor) for v in first_result.values()):
                # A dict with attention_mask is treated as a proper batch;
                # DistributedBatchMemory.concat handles its padding.
                if "attention_mask" in first_result:
                    return DistributedBatchMemory.concat(
                        [DistributedBatchMemory.from_dict(r) for r in results]
                    )
                # Simple tensor dict - concatenate tensors along batch dimension
                merged = {}
                for key, value in first_result.items():
                    if isinstance(value, torch.Tensor):
                        merged[key] = torch.cat([r[key] for r in results], dim=0)
                    else:
                        # Non-tensor values are assumed to be identical across workers
                        merged[key] = value
                return DistributedBatchMemory.from_dict(merged)

        # For non-tensor, non-dict results, assume they are already synchronized
        # (e.g., scalar statistics that have been all-reduced)
        return first_result

    def _align_batches_with_dp(
        self, input_: DistributedBatch, rebalance=True
    ) -> list[DistributedBatch]:
        """Split batch across DP groups. Uses chunk_by_ffd if rebalance=True, else simple chunking.

        An empty batch is replicated (the same object) to every DP group.
        """
        dp_size = self.alloc_mode.train.dp_size
        # Handle empty batch by replicating to all DP groups
        if len(input_.get_data()) == 0:
            return [input_] * dp_size

        if rebalance:
            # Use fair distribution based on sequence lengths (first-fit-decreasing)
            return input_.chunk_by_ffd(1, dp_size)
        # Simple sequential chunking
        return input_.chunk(dp_size)

    def export_stats(self):
        """Export training statistics from all workers.

        Collects statistics from all workers. The statistics are assumed to be
        already aggregated and synchronized (e.g., via all-reduce operations),
        so only the first result is returned.

        Returns
        -------
        dict[str, Any]
            Training statistics dictionary
        """

        async def _call_all():
            tasks = [
                self.scheduler.async_call_engine(worker.id, "export_stats")
                for worker in self.workers
            ]
            return await asyncio.gather(*tasks)

        results = self._run_async_task(_call_all())
        # Statistics have been aggregated and synchronized across workers
        # All results should be identical, so return the first one
        return results[0]

    # ==================== ENGINE RPC WRAPPERS ====================
    # Note: Methods like train_batch, forward, etc. are not implemented here.
    # They are expected to be called directly via _custom_function_call in
    # specific training scenarios (PPO, SFT, etc.) where the appropriate
    # loss functions and data processing are handled.
    def train(self, mode: bool = True):
        """Set the engine to training mode.

        Parameters
        ----------
        mode : bool, optional
            Whether to set the engine to training mode, by default True

        Returns
        -------
        TrainController
            Returns self for method chaining
        """
        self._custom_function_call("train", mode)
        return self

    def eval(self):
        """Set the engine to evaluation mode.

        This is a convenience method that calls `self.train(False)`.

        Returns
        -------
        TrainController
            Returns self for method chaining
        """
        return self.train(False)

    def set_version(self, version: int):
        """Set the current weight version in the training engine.

        Parameters
        ----------
        version : int
            The weight version number to set
        """
        self._custom_function_call("set_version", version)

    def get_version(self) -> int:
        """Get the current weight version in the training engine.

        Returns
        -------
        int
            The current weight version number
        """
        return self._custom_function_call("get_version")

    def save(self, meta: SaveLoadMeta):
        """Save model weights and optimizer states for later use.

        Parameters
        ----------
        meta : SaveLoadMeta
            Metadata containing information about where and how to save
        """
        self._custom_function_call("save", meta)

    def load(self, meta: SaveLoadMeta):
        """Load model weights and optimizer states from a file.

        Parameters
        ----------
        meta : SaveLoadMeta
            Metadata containing information about where and how to load
        """
        self._custom_function_call("load", meta)

    def step_lr_scheduler(self):
        """Step the learning rate scheduler.

        Since PPO uses minibatch updates, this method should be called periodically
        (e.g., once per PPO step). It is separated from train_batch to allow
        for more flexible learning rate scheduling.
        """
        self._custom_function_call("step_lr_scheduler")

    # ==================== SFT RPC WRAPPERS ====================
    def train_lm(
        self,
        input_: DistributedBatch,
        *args,
        **kwargs,
    ) -> dict[str, float]:
        """Train language model across workers.

        Parameters
        ----------
        input_ : DistributedBatch
            The distributed input data for language model training
        *args
            Additional positional arguments passed to the engine
        **kwargs
            Additional keyword arguments passed to the engine

        Returns
        -------
        dict[str, float]
            Scalar statistics after training
        """
        return self._custom_function_call("train_lm", input_, *args, **kwargs)

    def evaluate_lm(
        self,
        input_: DistributedBatch,
        *args,
        **kwargs,
    ) -> torch.Tensor | None:
        """Evaluate language model across workers.

        Parameters
        ----------
        input_ : DistributedBatch
            The distributed input data for language model evaluation
        *args
            Additional positional arguments passed to the engine
        **kwargs
            Additional keyword arguments passed to the engine

        Returns
        -------
        torch.Tensor or None
            A scalar loss or None
        """
        return self._custom_function_call("evaluate_lm", input_, *args, **kwargs)

    # =================== GRPO ========================================
    def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta):
        """Attach the rollout controller used for generation and weight updates.

        Replacing a previously connected, different controller logs a warning.
        When ``meta.type`` equals the platform communication backend, the
        distributed weight-update group is initialized, but only the first time.

        Parameters
        ----------
        rollout : RolloutController
            Controller that `prepare_batch`, `rollout_batch` and
            `update_weights` delegate to.
        meta : WeightUpdateMeta
            Weight-update configuration; ``meta.type`` selects whether the
            distributed group is initialized.
        """
        if self.rollout is not None and self.rollout != rollout:
            logger.warning(
                f"Connected rollout controller changed from {self.rollout} to {rollout}."
            )
        self.rollout = rollout

        if (
            meta.type == current_platform.communication_backend
            and not self.weight_update_group_initialized
        ):
            self._init_weight_update_from_distributed(meta)
            self.weight_update_group_initialized = True

    def prepare_batch(
        self,
        dataloader: StatefulDataLoader,
        workflow: str,
        workflow_kwargs: dict[str, Any],
        should_accept_fn: str | None = None,
    ) -> DistributedBatch:
        """Delegate to the connected rollout controller's ``prepare_batch``.

        Raises
        ------
        RuntimeError
            If no rollout controller has been connected via `connect_engine()`.
        """
        self._check_rollout_engine_connected()
        return self.rollout.prepare_batch(
            dataloader=dataloader,
            workflow=workflow,
            workflow_kwargs=workflow_kwargs,
            should_accept_fn=should_accept_fn,
        )

    def rollout_batch(
        self,
        data: list[dict[str, Any]],
        workflow: RolloutWorkflow | type[RolloutWorkflow] | str,
        workflow_kwargs: dict[str, Any],
        should_accept_fn: str | None = None,
    ) -> DistributedBatch:
        """Delegate to the connected rollout controller's ``rollout_batch``.

        Raises
        ------
        RuntimeError
            If no rollout controller has been connected via `connect_engine()`.
        """
        self._check_rollout_engine_connected()
        return self.rollout.rollout_batch(
            data=data,
            workflow=workflow,
            workflow_kwargs=workflow_kwargs,
            should_accept_fn=should_accept_fn,
        )

    def compute_logp(
        self,
        *args,
        **kwargs,
    ):
        """Compute log probabilities across workers.

        Parameters
        ----------
        *args
            Positional arguments passed to the engine
        **kwargs
            Keyword arguments passed to the engine

        Returns
        -------
        Any
            Log probabilities computed by the engine
        """
        return self._custom_function_call("compute_logp", *args, **kwargs)

    def compute_advantages(
        self,
        *args,
        **kwargs,
    ):
        """Compute advantages across workers.

        Parameters
        ----------
        *args
            Positional arguments passed to the engine
        **kwargs
            Keyword arguments passed to the engine

        Returns
        -------
        Any
            Advantages computed by the engine
        """
        return self._custom_function_call("compute_advantages", *args, **kwargs)

    def ppo_update(
        self,
        input_: DistributedBatch,
    ) -> dict[str, float]:
        """Perform PPO update step with the given batch.

        Parameters
        ----------
        input_ : DistributedBatch
            The distributed input data containing trajectories for PPO update

        Returns
        -------
        Dict[str, float]
            Scalar statistics after PPO update
        """
        return self._custom_function_call("ppo_update", input_)

    def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def _update_weights_from_disk(self, meta: WeightUpdateMeta):
        """Push new weights to the rollout side through a checkpoint on disk.

        Saves HF-format weights (no optimizer state) to ``meta.path``, publishes
        a timestamp under the versioned ``update_weights_from_disk`` name-resolve
        key, asks the rollout controller to load them, then deletes
        ``meta.path``.
        """
        # Update all LocalInfEngine's local weight
        self.save(
            SaveLoadMeta(
                path=meta.path,
                weight_format="hf",
                with_optim=False,
                tokenizer=None,
                processor=None,
            )
        )
        has_model_files = any(child.is_file() for child in Path(meta.path).iterdir())
        assert has_model_files, f"No model files found in {meta.path} after saving."

        update_name = names.update_weights_from_disk(
            self.config.experiment_name,
            self.config.trial_name,
            self.get_version(),
        )
        name_resolve.add(
            update_name,
            str(datetime.now().timestamp()),
            keepalive_ttl=120,
            replace=True,
        )

        # This method removes meta.path itself below.
        # NOTE(review): presumably clear_checkpoint=False stops the rollout side
        # from also deleting it -- confirm against RolloutController.
        meta.clear_checkpoint = False
        self._run_async_task(self.rollout.update_weights_from_disk(meta))
        shutil.rmtree(meta.path, ignore_errors=True)

    def _check_rollout_engine_connected(self):
        """Validate that rollout engine has been connected via connect_engine()."""
        if self.rollout is None:
            raise RuntimeError(
                "Rollout engine not connected. Call connect_engine()"
                " before using rollout/update_weight methods."
            )

    def update_weights(self, meta: WeightUpdateMeta):
        """Synchronize the current training weights to the rollout side.

        Parameters
        ----------
        meta : WeightUpdateMeta
            ``meta.type`` selects the transport: the platform communication
            backend (distributed; requires `connect_engine()` to have
            initialized the group) or ``"disk"``.

        Raises
        ------
        RuntimeError
            If no rollout controller is connected.
        ValueError
            If ``meta.type`` is not a supported transport.
        """
        self._check_rollout_engine_connected()
        if meta.type == current_platform.communication_backend:
            assert self.weight_update_group_initialized, (
                "Weight update group is not initialized; call connect_engine() first."
            )
            self._update_weights_from_distributed(meta)
        elif meta.type == "disk":
            self._update_weights_from_disk(meta)
        else:
            raise ValueError(f"Unknown weight update type {meta.type}")
